// Beware, this code is shared between the kernel and userspace.

#ifdef _KERNEL
#include <types.h>
#include <lib.h>
#include <synch.h>
#include <kern/errno.h>
#include <kern/secure.h>
#include <kern/test161.h>
#else
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <errno.h>
#include <unistd.h>
#include <test161/test161.h>
#include <test161/secure.h>
#endif

// Hack for allocating userspace memory without malloc, and for
// allowing secprintf in kmalloc when we're out of memory: every
// "allocation" below hands out the same fixed-size static buffer.
#define BUFFER_SIZE 1024

// Backing storage handed out by _alloc().
static char temp_buffer[BUFFER_SIZE];

#ifndef _KERNEL
// Formatting buffer used by say() to build a complete line before writing it.
static char write_buffer[BUFFER_SIZE];
#endif

#ifdef _KERNEL
// secprintf needs to be synchronized in the kernel because multiple threads
// may be trying to secprintf at the same time (they all share temp_buffer).
static struct semaphore *test161_sem;
#endif

/*
 * Minimal allocator: returns temp_buffer when the request fits in it, and
 * NULL otherwise. There is only one buffer, so at most one allocation can
 * be live at a time; callers must serialize (the kernel secprintf path
 * holds test161_sem around its use).
 *
 * size: number of bytes the caller intends to write into the buffer.
 * Returns temp_buffer, or NULL if size exceeds BUFFER_SIZE.
 */
static inline void *
_alloc(size_t size)
{
	if (size > BUFFER_SIZE) {
		// Refuse instead of letting the caller overrun temp_buffer.
		return NULL;
	}
	return temp_buffer;
}

/*
 * Counterpart to _alloc(). Nothing is ever really allocated, so there is
 * nothing to release; ptr is ignored and may be NULL.
 */
static inline void
_free(void *ptr)
{
	(void)ptr;
}

/*
 * Common pass/fail reporter for test161 tests: emits "SUCCESS" for name
 * when status is TEST161_SUCCESS, and "FAIL" for any other status.
 * If SECRET_TESTING is defined, secprintf computes the HMAC/SHA-256 hash
 * of the message using the shared secret and a random salt value. The
 * (secure) server also knows the secret and can verify that the message
 * was generated by a trusted source. The salt protects against replay
 * attacks.
 */
int
success(int status, const char *secret, const char *name)
{
	const char *verdict;

	verdict = (status == TEST161_SUCCESS) ? "SUCCESS" : "FAIL";
	return secprintf(secret, verdict, name);
}

/*
 * Report partial credit for test name as "PARTIAL CREDIT <scored> OF <total>".
 * The line goes through secprintf, just like success() output, so it is
 * signed whenever SECRET_TESTING is enabled.
 * Returns secprintf's result.
 */
int
partial_credit(const char *secret, const char *name, int scored, int total)
{
	// Comfortably holds the fixed text plus two fully-expanded ints.
	char buffer[128];

	snprintf(buffer, sizeof buffer, "PARTIAL CREDIT %d OF %d", scored, total);
	return secprintf(secret, buffer, name);
}
#ifndef _KERNEL

/*
 * printf-style console output for userspace tests, borrowed from parallelvm.
 * The whole message is formatted into write_buffer first and then handed to
 * a single write() on stdout. This is intended to keep our lines from being
 * intermingled with other output, since test161 parses the console line by
 * line. Output longer than BUFFER_SIZE - 1 bytes is truncated by vsnprintf.
 * Returns the result of write().
 */
static
int
say(const char *fmt, ...)
{
	va_list args;
	size_t nbytes;

	va_start(args, fmt);
	vsnprintf(write_buffer, BUFFER_SIZE, fmt, args);
	va_end(args);

	nbytes = strlen(write_buffer);
	return write(STDOUT_FILENO, write_buffer, nbytes);
}
#endif

#ifndef SECRET_TESTING

/*
 * Unsigned build: formats "name: msg" into buffer exactly as snprintf would
 * (at most len bytes, including the terminator). secret is accepted only so
 * the signature matches the SECRET_TESTING build; it is not used here.
 * Returns snprintf's result.
 */
int
snsecprintf(size_t len, char *buffer, const char *secret, const char *msg, const char *name)
{
	int count;

	(void)secret;	/* nothing to sign without SECRET_TESTING */
	count = snprintf(buffer, len, "%s: %s", name, msg);
	return count;
}

/*
 * Unsigned build: prints "\nname: msg\n" to the console. Userspace uses
 * say() and the kernel uses kprintf(). secret is ignored because there is
 * no signing without SECRET_TESTING.
 * Returns the printer's result.
 */
int
secprintf(const char *secret, const char *msg, const char *name)
{
	(void)secret;

#ifndef _KERNEL
	return say("\n%s: %s\n", name, msg);
#else
	return kprintf("\n%s: %s\n", name, msg);
#endif
}

#else

/*
 * Shared implementation of secprintf() and snsecprintf() for SECRET_TESTING
 * builds.
 *
 * Builds "name: msg", computes a salted HMAC of it with hmac_salted() using
 * secret, and emits "\n(name, hash, salt, name: msg)\n":
 *   - use_buf == 0: to the console (kprintf in the kernel, say in userspace);
 *   - use_buf != 0: into buffer via snprintf, bounded by b_len.
 *
 * Returns the printer's/snprintf's result on success, -ENOMEM if _alloc
 * returned NULL, or the negation of hmac_salted's nonzero return value.
 *
 * In the kernel the whole operation runs under test161_sem, because
 * fullmsg comes from _alloc's single static buffer.
 */
static int
secprintf_common(int use_buf, size_t b_len, char *buffer,
	const char *secret, const char *msg, const char *name)
{
	char *hash, *salt, *fullmsg;
	int res;
	size_t len;

#ifdef _KERNEL
	if (test161_sem == NULL) {
		panic("test161_sem is NULL. Your kernel is missing test161_bootstrap.");
	}
	P(test161_sem);
#endif

	// Start from NULL so the cleanup at "out" is safe on every path.
	hash = salt = fullmsg = NULL;

	// test161 expects "name: msg"
	len = strlen(name) + strlen(msg) + 3;	// +3 for ": " and the NUL terminator
	fullmsg = (char *)_alloc(len);
	if (fullmsg == NULL) {
		res = -ENOMEM;
		goto out;
	}
	snprintf(fullmsg, len, "%s: %s", name, msg);

	// Sign the message text only; len-1 excludes the NUL terminator.
	res = hmac_salted(fullmsg, len-1, secret, strlen(secret), &hash, &salt);
	if (res) {
		// Errors are reported as negative codes, matching -ENOMEM above.
		res = -res;
		goto out;
	}

	if (!use_buf) {
#ifdef _KERNEL
		res = kprintf("\n(%s, %s, %s, %s: %s)\n", name, hash, salt, name, msg);
#else
		res = say("\n(%s, %s, %s, %s: %s)\n", name, hash, salt, name, msg);
#endif
	} else {
		res = snprintf(buffer, b_len, "\n(%s, %s, %s, %s: %s)\n", name, hash, salt, name, msg);
	}

out:
	// These may be NULL, but that's OK
	// NOTE(review): _free is a no-op in this file; if hmac_salted allocates
	// hash/salt on the heap they are never released -- confirm in secure.c.
	_free(hash);
	_free(salt);
	_free(fullmsg);

	// Every path, including the error gotos, reaches this release.
#ifdef _KERNEL
	V(test161_sem);
#endif

	return res;
}

/*
 * Signed build: writes the signed "\n(name, hash, salt, name: msg)\n"
 * record into buffer (bounded by b_len, as with snprintf) instead of
 * printing it. Returns secprintf_common's result.
 */
int
snsecprintf(size_t b_len, char *buffer, const char *secret, const char *msg, const char *name)
{
	int result;

	result = secprintf_common(1, b_len, buffer, secret, msg, name);
	return result;
}

/*
 * Signed build: prints the signed "\n(name, hash, salt, name: msg)\n"
 * record for msg to the console. No caller buffer is involved.
 * Returns secprintf_common's result.
 */
int
secprintf(const char *secret, const char *msg, const char *name)
{
	char *no_buffer = NULL;

	return secprintf_common(0, 0, no_buffer, secret, msg, name);
}

#endif

#ifdef _KERNEL
/*
 * Create test161_sem, the binary semaphore (initial count 1) that serializes
 * kernel secprintf calls. It must run before the first secprintf in a
 * SECRET_TESTING kernel, which panics if test161_sem is still NULL.
 * Panics if the semaphore cannot be created.
 */
void test161_bootstrap(void)
{
	test161_sem = sem_create("test161", 1);
	if (test161_sem == NULL) {
		panic("Failed to create test161 secprintf semaphore");
	}
}
#endif
